Model selection¶
from IPython import get_ipython
# Only when get_ipython() returns something truthy: load the autoreload
# extension and switch it to mode "2".
if get_ipython():
    get_ipython().run_line_magic("load_ext", "autoreload")
    get_ipython().run_line_magic("autoreload", "2")
import numpy as np
import pandas as pd
import torch
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import collections
import latenta as la
# Set latenta's logger to the INFO level.
la.logger.setLevel("INFO")
Generative model¶
# One dimension entry per simulated cell, identified by its index as a string.
n_cells = 1000
cells = la.Dim(list(map(str, range(n_cells))), "cell")

# Two per-cell covariates, each drawn uniformly from [0, 1) and wrapped in
# la.Fixed; a Uniform distribution defined on the variable is attached to each.
x1 = la.Fixed(
    pd.Series(np.random.uniform(0, 1, cells.size), index=cells.index),
    label="x1",
    symbol="x1",
)
x1.distribution = la.distributions.Uniform(definition=x1)

x2 = la.Fixed(
    pd.Series(np.random.uniform(0, 1, cells.size), index=cells.index),
    label="x2",
    symbol="x2",
)
x2.distribution = la.distributions.Uniform(definition=x2)
# Registries filled by every gene-type section that follows:
#   gene_infos   -- gene id -> {"type": <gene type>}
#   gene_outputs -- gene type -> component producing that type's expression
gene_infos = {}
gene_outputs = {}

# Constant genes: recorded in gene_infos only, no entry in gene_outputs.
n_genes = 20
genes = la.Dim(
    pd.Series([f"constant {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": "constant"} for gene_id in genes.index}
)
def random_a(n_genes):
    """Draw one random, clearly non-zero coefficient per gene.

    Each value is a magnitude drawn uniformly from [1, 2) multiplied by a
    sign chosen at random from {-1, +1}.

    Parameters
    ----------
    n_genes : int
        Number of coefficients to draw.

    Returns
    -------
    numpy.ndarray
        One-dimensional array of length ``n_genes``.
    """
    return np.random.choice([-1, 1], n_genes) * np.random.uniform(1., 2., n_genes)
def random_x(n_genes):
    """Draw one random position per gene, uniformly from [0.15, 0.85).

    The covariates in this notebook are sampled from [0, 1), so these
    positions stay away from both ends of that range.

    Parameters
    ----------
    n_genes : int
        Number of positions to draw.

    Returns
    -------
    numpy.ndarray
        One-dimensional array of length ``n_genes``.
    """
    return np.random.uniform(0.15, 0.85, n_genes)
# Gene type "linear(x1)": Linear link on x1 with a random coefficient per gene.
n_genes = 20
gene_type = "linear(x1)"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

a = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
gene_outputs[gene_type] = la.links.scalar.Linear(x1, a)
# Gene type "switch(x1)": Switch link on x1 with, per gene, a random
# coefficient `a` and a random `switch` value drawn by random_x.
n_genes = 20
gene_type = "switch(x1)"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

a = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
switch = la.Fixed(pd.Series(random_x(n_genes), index=genes.index))
gene_outputs[gene_type] = la.links.scalar.Switch(x1, a=a, switch=switch)
# Gene type "linear(x2)": Linear link on x2 with a random coefficient per gene.
n_genes = 20
gene_type = "linear(x2)"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

a = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
gene_outputs[gene_type] = la.links.scalar.Linear(x2, a)
# Gene type "spline(x1)": Spline link on x1 with 10 knot values evenly spaced
# over [0, 1] and, per gene and knot, a coefficient drawn uniformly from [-1, 1).
n_genes = 20
gene_type = "spline(x1)"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

knot = la.Fixed(
    pd.Series(
        np.linspace(0., 1., 10),
        index=pd.Series(range(10), name="knot"),
    )
)
a = la.Fixed(
    pd.DataFrame(
        np.random.rand(n_genes, knot[0].size) * 2 - 1,
        columns=pd.Series(range(10), name="knot"),
        index=genes.index,
    )
)
gene_outputs[gene_type] = la.links.scalar.Spline(
    x1, a=a, knot=knot, smoothness=la.Fixed(10.)
)
# Gene type "linear(x1) + linear(x2)": an Additive component combining one
# Linear link per covariate, each with its own random coefficients.
n_genes = 20
gene_type = "linear(x1) + linear(x2)"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

a1 = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
a2 = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
gene_outputs[gene_type] = la.modular.Additive(
    x1=la.links.scalar.Linear(x1, a1),
    x2=la.links.scalar.Linear(x2, a2),
    definition=la.Definition([cells, genes]),
)
# Gene type "linear([x1, x2])": a single multi-input Linear link
# (la.links.scalars) taking both covariates, with one random coefficient per gene.
n_genes = 20
gene_type = "linear([x1, x2])"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

a = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
gene_outputs[gene_type] = la.links.scalars.Linear([x1, x2], a=a)
# Gene type "switch([x1, x2])": a multi-input Linear link whose two inputs are
# Switch links on x1 and x2, each with its own random per-gene `switch` values.
n_genes = 20
gene_type = "switch([x1, x2])"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

a = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
gene_outputs[gene_type] = la.links.scalars.Linear(
    [
        la.links.scalar.Switch(
            x1,
            switch=la.Fixed(pd.Series(random_x(n_genes), index=genes.index)),
        ),
        la.links.scalar.Switch(
            x2,
            switch=la.Fixed(pd.Series(random_x(n_genes), index=genes.index)),
        ),
    ],
    a=a,
)
# Gene type "linear([switch(x1), x2])": a multi-input Linear link combining a
# Switch on x1 (a fixed to 1, switch fixed to 0.5, shared by all genes) with
# the raw x2, using one random coefficient per gene.
n_genes = 20
gene_type = "linear([switch(x1), x2])"
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

a = la.Fixed(pd.Series(random_a(n_genes), index=genes.index), label="a")
gene_outputs[gene_type] = la.links.scalars.Linear(
    [
        la.links.scalar.Switch(
            x1,
            a=la.Fixed(1.),
            switch=la.Fixed(0.5),
            label="X_activity",
        ),
        x2,
    ],
    a=a,
)
# Gene type "complicated([x1, x2])": a Thinplate link over both covariates,
# with knots laid out on a 10 x 10 grid covering [0, 1] x [0, 1].
n_genes = 20
gene_type = "complicated([x1, x2])"
# Build a fresh Dim, as every other gene-type section does, rather than
# assigning to `.index` of the Dim left over from the previous gene type.
genes = la.Dim(
    pd.Series([f"{gene_type} {i}" for i in range(n_genes)], name="gene")
)
gene_infos.update(
    {gene_id: {"type": gene_type} for gene_id in genes.index}
)

n_knots = [10, 10]
# Every (x1, x2) grid position as one row: shape (n_knots[0] * n_knots[1], 2).
knots_value = np.dstack(
    np.meshgrid(
        np.linspace(0, 1, n_knots[0]),
        np.linspace(0, 1, n_knots[1]),
    )
).reshape(-1, 2)
knots_dim = la.Dim(pd.Series(range(knots_value.shape[0]), name="knot"))
knots = [
    la.Fixed(pd.Series(knots_value[:, 0], index=knots_dim.index)),
    la.Fixed(pd.Series(knots_value[:, 1], index=knots_dim.index)),
]

# NOTE(review): `* 4 - 1` draws coefficients from [-1, 3), whereas the
# spline(x1) genes use `* 2 - 1` (i.e. [-1, 1)) -- confirm this is intended.
a = la.Fixed(
    pd.DataFrame(
        np.random.rand(n_genes, knots_dim.size) * 4 - 1,
        columns=knots_dim.index,
        index=genes.index,
    )
)
smoothness = la.Fixed(10.)
gene_outputs[gene_type] = la.links.scalars.Thinplate(
    [x1, x2], knots=knots, a=a, smoothnesses=smoothness, n_knots=n_knots
)
# Gene annotation table (one row per gene, column "type") and a single gene
# dimension spanning the genes of every type.
gene_info = pd.DataFrame.from_dict(gene_infos, orient="index")
gene_info.index.name = "gene"
genes = la.Dim(gene_info.index)

# Attach each gene-type component to one Additive over cells x genes.
output = la.modular.Additive(
    0.,
    la.Definition([cells, genes]),
    label="output",
    subsettable={genes},
)
for component_id, component in gene_outputs.items():
    setattr(output, component_id, component)

# Observations are Normal around `output` with a fixed scale of 0.5.
scale = la.Fixed(0.5)
dist = la.distributions.Normal(loc=output, scale=scale)

model_gs = la.Model(dist)
model_gs.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe2234a0fa0>
# Draw a single sample from the generative model and use it as the observed
# expression table.
posterior = la.posterior.Posterior(dist)
posterior.sample(1)
observation_value = posterior.samples[dist].sel(sample=0).to_pandas()

# Heatmap of the simulated values with cells ordered by their x1 value
# (drawn on the first of the two axes).
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5))
cell_order = model_gs.find("x1").prior_pd().sort_values().index
sns.heatmap(observation_value.loc[cell_order], ax=ax0)
<AxesSubplot:xlabel='gene', ylabel='cell'>
output.empirical = observation_value

# Causal posteriors of the generative model itself: x1 alone, x2 alone, and
# the two combined.
x1_causal = la.posterior.scalar.ScalarVectorCausal(
    x1, dist, observed=posterior
)
x1_causal.sample(5)

x2_causal = la.posterior.scalar.ScalarVectorCausal(
    x2, dist, observed=posterior
)
x2_causal.sample(5)

x1_x2_causal = la.posterior.scalarscalar.ScalarScalarVectorCausal(
    x1_causal, x2_causal
)
x1_x2_causal.sample(5, n_batch=40)

# One randomly chosen gene per gene type, reused for the plots further down.
gene_ids = gene_info.groupby("type").sample(1).index

x1_causal.plot_features(feature_ids=gene_ids);
x2_causal.plot_features(feature_ids=gene_ids);
x1_x2_causal.plot_features_contour(feature_ids=gene_ids);
/home/wsaelens/projects/probabilistic-cell/latenta/src/latenta/posterior/scalar/scalarscalar/causal.py:297: UserWarning: No contour levels were found within the data range.
contour = ax.contour(data["x"]["mesh"], data["y"]["mesh"], data["z"]["mesh"], cmap=cmap)
Creating the different models¶
# Registry of candidate models, keyed by a descriptive model id.
models = {}
Constant¶
output.reset()

# Baseline model: a per-gene intercept plus Constant links on x1 and x2, with
# a per-gene scale parameter for the Normal observation distribution.
mu = la.modular.Additive(
    intercept=la.Parameter(0., la.Definition([genes])),
    definition=output.value_definition,
    subsettable={genes},
)
s = la.Parameter(
    1.,
    definition=la.Definition([genes]),
    transforms=la.distributions.Exponential().biject_to(),
)
mu.x1 = la.links.scalar.Constant(x1, output=mu.value_definition)
mu.x2 = la.links.scalar.Constant(x2, output=mu.value_definition)

dist = la.distributions.Normal(mu, s)
observation = la.Observation(observation_value, dist, label="observation")
mu.empirical = observation_value

model_constant = la.Model(observation)
models["constant"] = model_constant
model_constant.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe29a307fa0>
# Check what clone() shares and what it copies: the x1 loader value must be
# the very same object in both models, the scale component must not be.
model_constant2 = model_constant.clone()
assert (
    model_constant.observation.p.loc.x1.x.loader.value
    is model_constant2.observation.p.loc.x1.x.loader.value
)
assert (
    model_constant.observation.p.scale
    is not model_constant2.observation.p.scale
)
Additive spline model¶
# Start from a clone of the constant model's observation and replace the x1
# and x2 components by Spline links.
model = la.Model(model_constant.observation.clone())
mu = model.observation.p.loc

mu.x1 = la.links.scalar.Spline(x1, output=mu)
mu.x2 = la.links.scalar.Spline(x2, output=mu)

models["spline(x1) + spline(x2)"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe28c1a00d0>
Thin plate¶
# Start from a clone of the constant model's observation, drop the separate
# x1 and x2 components and add one joint Thinplate link on both covariates.
model = la.Model(model_constant.observation.clone())
mu = model.observation.p.loc

del mu.x1
del mu.x2
mu.x12 = la.links.scalars.Thinplate({"x1": x1, "x2": x2}, output=mu)

models["thinplate([x1, x2])"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe40efe8d90>
Additive switch model¶
# Start from a clone of the constant model's observation and replace the x1
# and x2 components by Switch links created with switch=True and a=True.
model = la.Model(model_constant.observation.clone())
# From here on x1 / x2 refer to the components found inside this model.
x1 = model.find("x1")
x2 = model.find("x2")
mu = model.observation.p.loc

mu.x1 = la.links.scalar.Switch(
    x1, switch=True, a=True, output=mu.value_definition
)
mu.x2 = la.links.scalar.Switch(
    x2, switch=True, a=True, output=mu.value_definition
)

models["switch(x1) + switch(x2)"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe4177d9a00>
Multiplicative switch model¶
# Start from a clone of the constant model's observation, drop the separate
# x1 and x2 components and add one multi-input Linear link fed by a Switch on
# each covariate.
model = la.Model(model_constant.observation.clone())
x1 = model.find("x1")
x2 = model.find("x2")
mu = model.observation.p.loc

del mu.x1
del mu.x2
mu.x12 = la.links.scalars.Linear(
    [
        la.links.scalar.Switch(x1, switch=True, output=mu),
        la.links.scalar.Switch(x2, switch=True, output=mu),
    ],
    a=la.Definition([genes]),
)

models["switch([x1, x2])"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe28c1c8ca0>
Linear additive¶
# Start from a clone of the constant model's observation and replace the x1
# and x2 components by Linear links created with a=True.
model = la.Model(model_constant.observation.clone())
x1 = model.find("x1")
x2 = model.find("x2")
mu = model.observation.p.loc

mu.x1 = la.links.scalar.Linear(x1, a=True, output=mu)
mu.x2 = la.links.scalar.Linear(x2, a=True, output=mu)

models["linear(x1) + linear(x2)"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe41663e5e0>
Additive linear and interaction¶
# Same as the linear additive model, plus a multi-input Linear link on
# [x1, x2] stored as the extra component `x12`.
model = la.Model(model_constant.observation.clone())
x1 = model.find("x1")
x2 = model.find("x2")
mu = model.observation.p.loc

mu.x1 = la.links.scalar.Linear(x1, a=True, output=mu)
mu.x2 = la.links.scalar.Linear(x2, a=True, output=mu)
mu.x12 = la.links.scalars.Linear([x1, x2], a=True, output=mu)

models["linear(x1) + linear(x2) + linear([x1, x2])"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe416604df0>
linear([switch(x1), x2])¶
# Start from a clone of the constant model's observation, drop the separate
# x1 and x2 components and add one multi-input Linear link combining a Switch
# on x1 with the raw x2.
model = la.Model(model_constant.observation.clone())
x1 = model.find("x1")
x2 = model.find("x2")
mu = model.observation.p.loc

del mu.x1
del mu.x2
mu.x12 = la.links.scalars.Linear(
    [
        la.links.scalar.Switch(
            x1,
            switch=True,
            label="x1 activity",
            symbol="x1_activity",
            output=mu,
        ),
        x2,
    ],
    a=la.Definition([genes]),
    output=mu,
)

models["linear([switch(x1), x2])"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe4165e7220>
embedding¶
To illustrate why model selection is necessary, we can also create an extremely flexible embedding model that will certainly overfit. Overfitting in this context means that the embedding will contain information coming from technical noise.
Note that, in a typical use case, we would use an amortization function to infer the latent space, but this is not necessary in this case as we’re simply using a linear function on a simple dataset.
# Start from a clone of the constant model's observation and add a
# 20-component latent embedding mapped onto genes through a Matmul link.
# The x1 / x2 components are kept: their deletion is commented out below.
model = la.Model(model_constant.observation.clone())
mu = model.observation.p.loc
# del mu.x1
# del mu.x2

components = la.Dim(20, "component")
embedding = la.Latent(
    la.distributions.Normal(0., 1.),
    definition=la.Definition([cells, components]),
    label="embedding",
    symbol="embedding",
)
a = la.Latent(
    la.distributions.Normal(),
    definition=la.Definition([genes, components]),
)
mu.embedding = la.links.vector.Matmul(embedding, a)

models["embedding"] = model
model.plot()
<latenta.introspect.viz.component_graph.ComponentGraph at 0x7fe4165a0130>
# Fit every candidate model and store, on the model itself, the training
# trace and the posteriors used by the analyses below. Each step is skipped
# when its key is already present on the model, so re-running this cell only
# computes what is missing. Delete a key (see the commented hints) to force
# that step to be recomputed.
for model_id, model in models.items():
    print(model_id)
    # To restrict the run to one model while debugging:
    # if model_id != "constant":
    #     continue

    x1 = model.observation.find("x1")
    x2 = model.observation.find("x2")

    # if "trace" in model: del model["trace"]
    if "trace" not in model:
        # Two consecutive training runs, first with lr = 0.05 and then with
        # lr = 0.01. The second run overwrites model["trace"], so only the
        # trace of the last run is kept.
        for lr in (0.05, 0.01):
            inference = la.infer.svi.SVI(
                model, [la.infer.loss.ELBO()], la.infer.optim.Adam(lr=lr)
            )
            trainer = la.infer.trainer.Trainer(inference)
            model["trace"] = trainer.train(1000)

    # if "observed" in model: del model["observed"]
    if "observed" not in model:
        model["observed"] = la.posterior.vector.VectorObserved(model.observation)
        model["observed"].sample(5)

    # if "x1_causal" in model: del model["x1_causal"]
    if "x1_causal" not in model:
        x1_causal = la.posterior.scalar.ScalarVectorCausal(
            x1, model.observation, observed=model["observed"]
        )
        x1_causal.sample(20)
        x1_causal.sample_random(5)
        model["x1_causal"] = x1_causal

    # if "x2_causal" in model: del model["x2_causal"]
    if "x2_causal" not in model:
        x2_causal = la.posterior.scalar.ScalarVectorCausal(
            x2, model.observation, observed=model["observed"]
        )
        x2_causal.sample(20)
        x2_causal.sample_random(5)
        model["x2_causal"] = x2_causal

    # if "x1_x2_causal" in model: del model["x1_x2_causal"]
    if "x1_x2_causal" not in model:
        x1_x2_causal = la.posterior.scalarscalar.ScalarScalarVectorCausal(
            model["x1_causal"], model["x2_causal"]
        )
        x1_x2_causal.sample(20, n_batch=20)
        model["x1_x2_causal"] = x1_x2_causal
constant
spline(x1) + spline(x2)
thinplate([x1, x2])
switch(x1) + switch(x2)
switch([x1, x2])
switch(x1, s1) + switch(x2, s2) + switch([x1, x2], [s1, s2])
linear(x1) + linear(x2)
linear(x1) + linear(x2) + linear([x1, x2])
linear([switch(x1), x2])
embedding
Explore a model
# Choose the fitted model to inspect; uncomment another line to switch.
# model = models["linear(x1) + linear(x2) + linear([x1, x2])"]
# model = models["thinplate([x1, x2])"]
# model = models["constant"]
# model = models["switch(x1) + switch(x2)"]
# model = models["switch([x1, x2])"]
# model = models["spline(x1) + spline(x2)"]
model = models["linear([switch(x1), x2])"]
# model = models["embedding"]

# Per-covariate feature plots, then contour plots over both covariates.
model["x1_causal"].plot_features();
model["x2_causal"].plot_features();
model["x1_x2_causal"].plot_features_contour(feature_ids=gene_ids);
# First five genes simulated as linear([switch(x1), x2]).
model["x1_x2_causal"].plot_features_contour(
    feature_ids=gene_info.query("type == 'linear([switch(x1), x2])'").index[:5]
);

fig = model["x1_x2_causal"].plot_likelihood_ratio();
fig.axes[0].legend(
    bbox_to_anchor=(0.5, 1.1),
    ncol=2,
    title="coregulatory",
)
<matplotlib.legend.Legend at 0x7fe41ce0ffa0>
# Scatter the "lr_x1" and "lr_x2" score columns against each other, one point
# per gene, coloured by the simulated gene type.
ax = sns.scatterplot(
    x="lr_x1",
    y="lr_x2",
    data=model["x1_x2_causal"].scores.join(gene_info),
    hue="type",
)
ax.legend(bbox_to_anchor=(1.1, 1.05))
# NOTE(review): "sigmoid" is not one of matplotlib's built-in scale names;
# presumably it is registered by latenta -- confirm.
ax.set_xscale("sigmoid")
ax.set_yscale("sigmoid")

model["x1_x2_causal"].plot_features_contour(
    # feature_ids = gene_info.query("type == 'switch(x1)'").index
);
Model selection¶
# Two assignments in a row: the second one wins, and the loop below rebinds
# `model` again on every iteration.
model = models["switch(x1) + switch(x2)"]
model = models["switch([x1, x2])"]
# model = models["thinplate([x1, x2])"]
# model["observed"].sample(1)

# Print the mean ELBO of every fitted model.
for model_id, model in models.items():
    print(model_id, model["observed"].elbo.mean().item())
constant -199119.109375
spline(x1) + spline(x2) -157483.01318359375
thinplate([x1, x2]) -159334.84765625
switch(x1) + switch(x2) -160374.72985839844
switch([x1, x2]) -160205.75830078125
switch(x1, s1) + switch(x2, s2) + switch([x1, x2], [s1, s2]) -160354.8472623825
linear(x1) + linear(x2) -161801.66522216797
linear(x1) + linear(x2) + linear([x1, x2]) -159300.88177490234
linear([switch(x1), x2]) -168151.75451660156
embedding -171802.8017578125
# Stack the per-feature likelihoods of all models into a models x genes table
# and drop the "embedding" model from the comparison.
likelihoods = xr.concat(
    [model["observed"].likelihood_features for model in models.values()],
    dim=pd.Series(models.keys(), name="model"),
).to_pandas()
model_ids = [
    model_id for model_id in likelihoods.index if model_id != "embedding"
]
likelihoods = likelihoods.loc[model_ids]

# Same table built from the per-feature ELBO.
evidences = xr.concat(
    [model["observed"].elbo_features for model in models.values()],
    dim=pd.Series(models.keys(), name="model"),
).to_pandas()
evidences = evidences.loc[model_ids]

# Genes on rows, models on columns: True where a model has the highest
# likelihood for that gene.
sns.heatmap((likelihoods == likelihoods.max(0)).T)
<AxesSubplot:xlabel='model', ylabel='gene'>
# Per gene (rows), each model's likelihood minus the best likelihood for that
# gene: 0 for the best model, negative otherwise.
sns.heatmap((likelihoods - likelihoods.max(0)).T)
<AxesSubplot:xlabel='model', ylabel='gene'>
# True where a model's likelihood is less than 5 below the best likelihood
# for that gene.
sns.heatmap((likelihoods.max(0) - likelihoods).T < 5)
<AxesSubplot:xlabel='model', ylabel='gene'>
# Optional filter (disabled): drop models whose id starts with "thinpl".
# evidences = evidences.loc[[i for i in evidences.index if not i.startswith("thinpl")]]
# Raw per-gene ELBO values, genes on rows and models on columns.
sns.heatmap((evidences).T)
<AxesSubplot:xlabel='model', ylabel='gene'>
# True where a model's ELBO is less than 5 below the best ELBO for that gene.
sns.heatmap((evidences.max(0) - evidences).T < 5)
<AxesSubplot:xlabel='model', ylabel='gene'>
# True where a model has the highest ELBO for that gene (DataFrame.max()
# reduces over rows, i.e. over models, by default).
sns.heatmap((evidences == evidences.max()).T)
<AxesSubplot:xlabel='model', ylabel='gene'>
# For every gene, flag the model(s) with the highest ELBO.
selected_evidence = evidences == evidences.max()

# Likelihood of every model relative to the evidence-selected model of the
# same gene. `where` replaces the non-selected entries with NaN, which `min`
# skips, so the reference is the selected model's likelihood whatever its
# sign. (Multiplying by the boolean mask instead turns non-selected entries
# into 0, which only gives the same result while the selected likelihood is
# negative.)
likelihood_diff = likelihoods - likelihoods.where(selected_evidence).min()

# A gene counts as "undermodelled" when some model beats the selected one by
# more than 1 in likelihood.
undermodelled = (likelihood_diff > 1).any()
gene_info["undermodelled"] = undermodelled
gene_info["selected_evidence"] = selected_evidence.idxmax(0)

# Number of genes per simulated type, split by the undermodelled flag.
(
    gene_info.groupby(["type", "undermodelled"])
    .count()["selected_evidence"]
    .unstack()
    .T
    .fillna(0.)
    .style.background_gradient(cmap=sns.color_palette("Blues", as_cmap=True))
)
| type | complicated([x1, x2]) | constant | linear([switch(x1), x2]) | linear([x1, x2]) | linear(x1) | linear(x1) + linear(x2) | linear(x2) | spline(x1) | switch([x1, x2]) | switch(x1) |
|---|---|---|---|---|---|---|---|---|---|---|
| undermodelled | ||||||||||
| False | 9.000000 | 3.000000 | 20.000000 | 15.000000 | 12.000000 | 13.000000 | 10.000000 | 16.000000 | 20.000000 | 18.000000 |
| True | 11.000000 | 17.000000 | 0.000000 | 5.000000 | 8.000000 | 7.000000 | 10.000000 | 4.000000 | 0.000000 | 2.000000 |
# Contingency table: evidence-selected model (rows) versus simulated gene
# type (columns), shaded by count.
pd.crosstab(gene_info["selected_evidence"], gene_info["type"]).style.background_gradient(cmap=sns.color_palette("Blues", as_cmap=True))
| type | complicated([x1, x2]) | constant | linear([switch(x1), x2]) | linear([x1, x2]) | linear(x1) | linear(x1) + linear(x2) | linear(x2) | spline(x1) | switch([x1, x2]) | switch(x1) |
|---|---|---|---|---|---|---|---|---|---|---|
| selected_evidence | ||||||||||
| constant | 0 | 20 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| linear([switch(x1), x2]) | 0 | 0 | 20 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| linear(x1) + linear(x2) | 0 | 0 | 0 | 0 | 19 | 20 | 20 | 0 | 0 | 0 |
| linear(x1) + linear(x2) + linear([x1, x2]) | 0 | 0 | 0 | 19 | 1 | 0 | 0 | 0 | 0 | 0 |
| spline(x1) + spline(x2) | 8 | 0 | 0 | 0 | 0 | 0 | 0 | 16 | 0 | 0 |
| switch([x1, x2]) | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 20 | 0 |
| switch(x1) + switch(x2) | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 13 |
| switch(x1, s1) + switch(x2, s2) + switch([x1, x2], [s1, s2]) | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 7 |
| thinplate([x1, x2]) | 8 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
# Largest likelihood gain over the evidence-selected model, per gene, grouped
# by simulated gene type. `x` is passed by keyword because seaborn deprecated
# passing it positionally (from 0.12 the only positional argument is `data`).
sns.stripplot(x=likelihood_diff.max(), y=gene_info["type"])
/home/wsaelens/projects/probabilistic-cell/.venv/lib/python3.8/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variable as a keyword arg: x. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
warnings.warn(
<AxesSubplot:ylabel='type'>